import torch
import torch.nn.functional as F
import torch.nn as nn
from timm import create_model
class LeNet(nn.Module):
    """LeNet-5 style CNN sized for 3-channel 32x32 images (e.g. CIFAR-10).

    Args:
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10), which
            matches the original fixed head.

    Shape:
        Input ``(N, 3, 32, 32)``; output ``(N, num_classes)`` raw logits
        (no softmax is applied).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # CIFAR-10 images have 3 channels
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Spatial size for a 32x32 input: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5,
        # so the flattened feature vector has 16 * 5 * 5 elements.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dim. The former view(-1, 400) would
        # silently reinterpret extra spatial elements as additional batch rows for
        # non-32x32 inputs; with flatten, fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

def build_model(args):
    """Instantiate the model named by ``args.model`` and move it to ``args.device``.

    Args:
        args: Config namespace. Reads ``model`` (one of ``'resnet20'``,
            ``'resnet50'``, ``'vit_base_patch16_224'``, ``'lenet'``),
            ``pretrained`` and ``output_dim`` (used for the timm models only)
            and ``device``.

    Returns:
        The constructed ``nn.Module`` on ``args.device``.

    Raises:
        ValueError: If ``args.model`` is not a supported name.
    """
    # These names are forwarded verbatim to timm.create_model.
    # NOTE(review): confirm 'resnet20' is registered in the installed timm
    # version (or a local plugin); otherwise create_model itself will raise.
    timm_models = ('resnet20', 'resnet50', 'vit_base_patch16_224')

    if args.model in timm_models:
        model = create_model(args.model, pretrained=args.pretrained, num_classes=args.output_dim)
    elif args.model == 'lenet':
        # LeNet is built with its default head; args.output_dim and
        # args.pretrained are not used for this model.
        model = LeNet()
    else:
        # Previously fell through and failed with UnboundLocalError on `model`.
        raise ValueError(
            f"Unsupported model {args.model!r}; expected one of "
            f"{', '.join(timm_models + ('lenet',))}"
        )
    return model.to(args.device)

def load_model(model_path, args):
    """Build the model described by ``args`` and load weights from ``model_path``.

    Keys beginning with ``'module.'`` (the prefix DistributedDataParallel adds
    when a wrapped model's state dict is saved) have that prefix removed so the
    weights fit the bare model returned by ``build_model``.

    Args:
        model_path: Path to a checkpoint file. Assumed to contain a plain state
            dict (mapping of parameter name -> tensor) — TODO confirm callers
            never pass a wrapper dict.
        args: Config namespace forwarded to ``build_model``; ``args.device`` is
            also used as ``map_location`` when loading.

    Returns:
        The model with the checkpoint's weights loaded.

    Raises:
        RuntimeError: From ``load_state_dict`` when keys or shapes do not match.
    """
    model = build_model(args)
    # map_location remaps stored tensors onto the target device, so a checkpoint
    # saved on GPU can be loaded on a CPU-only machine or a different GPU index.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=args.device)

    # Strip the DDP 'module.' prefix; other keys pass through unchanged.
    prefix = 'module.'
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    # Load the de-prefixed state dict into the model.
    model.load_state_dict(new_state_dict)

    return model